from plat.plat_train import get_train_info, plat_train_val, plat_trains

"""
比较层数大小对最终结果的影响
网络层数分别取2, 4, 6, 8, 12
"""
if __name__ == '__main__':
    # Directory (relative to the current working directory) holding the logs.
    log_dir = "../log/GatedGCN/layer/"

    # (label, log file name) pairs, listed in the order they are processed.
    # The label is used both as the dict key and as the first argument
    # passed to plat_train_val.
    runs = (
        ('L_2', "L_2.log"),
        ('L_4', "L_4.log"),
        ('L_6', "L_6.log"),
        ('L_8', "L_8.log"),
        ('L_12', "L_12.log"),
    )

    train_by_label = {}
    val_by_label = {}

    # Read every log first. get_train_info returns a pair; its first item is
    # stored as the "train" entry and its second as the "val" entry.
    # NOTE(review): presumably training / validation curves parsed from the
    # log file -- confirm against plat.plat_train.
    for label, file_name in runs:
        train_part, val_part = get_train_info(log_dir + file_name)
        train_by_label[label] = train_part
        val_by_label[label] = val_part

    # Only after all logs are loaded: one plat_train_val call per label.
    for label, _ in runs:
        plat_train_val(label, train_by_label[label], val_by_label[label])

    # Finally, pass the complete per-label dicts to plat_trains, train first
    # and then val.
    plat_trains(train_by_label)
    plat_trains(val_by_label)
